/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {
  ActionMetadata,
  MediaPart,
  MessageData,
  modelActionMetadata,
  z,
} from 'genkit';
import {
  getBasicUsageStats,
  modelRef,
  type GenerateRequest,
  type ModelAction,
  type ModelInfo,
  type ModelReference,
} from 'genkit/model';
import { model as pluginModel } from 'genkit/plugin';
import { imagenPredict } from './client.js';
import type {
  ClientOptions,
  GoogleAIPluginOptions,
  ImagenParameters,
  ImagenPredictRequest,
  ImagenPrediction,
  Model,
} from './types.js';
import {
  calculateApiKey,
  checkApiKey,
  checkModelName,
  extractImagenImage,
  extractText,
  extractVersion,
  modelName,
} from './utils.js';

/**
 * Config schema for Imagen models served through the Gemini API.
 *
 * The object is `.passthrough()`, so keys that are not declared here are kept
 * after parsing; `toImagenParameters` spreads the whole config into the
 * predict request's `parameters`, so such extra keys are forwarded as-is.
 *
 * See https://ai.google.dev/gemini-api/docs/image-generation#imagen-model
 */
export const ImagenConfigSchema = z
  .object({
    // Per-request API key. It is stripped from the request parameters in
    // toImagenParameters and passed to calculateApiKey instead.
    apiKey: z
      .string()
      .describe('Override the API key provided at plugin initialization.')
      .optional(),

    // Becomes the request's `sampleCount` (1 when unset) in toImagenParameters.
    // NOTE(review): the 1-4 range is documented but not validated here.
    numberOfImages: z
      .number()
      .describe(
        'The number of images to generate, from 1 to 4 (inclusive). The default is 1.'
      )
      .optional(),
    aspectRatio: z
      .enum(['1:1', '9:16', '16:9', '3:4', '4:3'])
      .describe('Desired aspect ratio of the output image.')
      .optional(),
    personGeneration: z
      .enum(['dont_allow', 'allow_adult', 'allow_all'])
      .describe(
        'Control if/how images of people will be generated by the model.'
      )
      .optional(),
  })
  .passthrough();
/** The zod schema type of `ImagenConfigSchema`. */
export type ImagenConfigSchemaType = typeof ImagenConfigSchema;
/** Parsed (output) type of `ImagenConfigSchema`, including passthrough keys. */
export type ImagenConfig = z.infer<ImagenConfigSchemaType>;

// Union of every config schema an Imagen model may use (currently only one).
type ConfigSchemaType = typeof ImagenConfigSchema;

/**
 * Builds a `googleai/`-prefixed model reference for an Imagen model.
 *
 * @param name Bare model name, e.g. `imagen-4.0-generate-001`.
 * @param info Capability info. When omitted, the reference describes a
 *   single-turn model that accepts media, produces media output, and supports
 *   neither tools, tool choice, nor a system role.
 * @param configSchema Config schema attached to the reference; defaults to
 *   `ImagenConfigSchema`.
 * @returns The model reference.
 */
function commonRef(
  name: string,
  info?: ModelInfo,
  configSchema: ConfigSchemaType = ImagenConfigSchema
): ModelReference<ConfigSchemaType> {
  const qualifiedName = `googleai/${name}`;
  const resolvedInfo: ModelInfo = info ?? {
    supports: {
      media: true,
      multiturn: false,
      tools: false,
      toolChoice: false,
      systemRole: false,
      output: ['media'],
    },
  };

  return modelRef({
    name: qualifiedName,
    configSchema,
    info: resolvedInfo,
  });
}

// Capability template for Imagen models that are not listed in KNOWN_MODELS;
// model() copies this `info` into the reference it builds for such names.
// Allow all the capabilities for unknown future models: this is deliberately
// more permissive than commonRef()'s default info (multiturn, tools and
// systemRole are advertised as supported).
const GENERIC_MODEL = commonRef('imagen', {
  supports: {
    media: true,
    multiturn: true,
    tools: true,
    systemRole: true,
    output: ['media'],
  },
});

// Imagen models with predefined references, keyed by bare model name.
// listKnownModels() iterates Object.keys(KNOWN_MODELS), so the declaration
// order here is the order in which their actions are defined.
const KNOWN_MODELS = {
  'imagen-4.0-fast-generate-001': commonRef('imagen-4.0-fast-generate-001'),
  'imagen-4.0-generate-001': commonRef('imagen-4.0-generate-001'),
  'imagen-4.0-ultra-generate-001': commonRef('imagen-4.0-ultra-generate-001'),
} as const;
// Union of the known model names.
export type KnownModels = keyof typeof KNOWN_MODELS; // For autocomplete

/**
 * Any model name that begins with `imagen-`.
 * Used for conditional types in index.ts `model()`.
 */
export type ImagenModelName = `imagen-${string}`;

/**
 * Type guard for Imagen model names.
 *
 * @param value Candidate model name (bare, without any `models/` or
 *   `googleai/` prefix).
 * @returns `true` when `value` is a string starting with `imagen-`; `false`
 *   for other strings and for a missing value.
 */
export function isImagenModelName(value?: string): value is ImagenModelName {
  if (value == null) {
    return false;
  }
  return value.startsWith('imagen-');
}

/**
 * Returns a model reference for an Imagen model.
 *
 * Names listed in `KNOWN_MODELS` return their predefined reference with
 * `config` applied. Any other name gets a fresh `googleai/<name>` reference
 * that uses `ImagenConfigSchema` and copies the permissive `GENERIC_MODEL`
 * capability info.
 *
 * @param version Model name; normalized through `checkModelName` (presumably
 *   strips a `models/` / `googleai/` prefix — confirm in utils.ts).
 * @param config Default config attached to the reference.
 * @returns The model reference.
 */
export function model(
  version: string,
  config: ImagenConfig = {}
): ModelReference<ConfigSchemaType> {
  const name = checkModelName(version);
  // Own-property check: a bare `KNOWN_MODELS[name]` lookup also resolves
  // inherited Object.prototype members ('toString', 'constructor', ...), which
  // would then crash on `.withConfig`. Such names now fall through to the
  // generic reference like any other unknown model.
  if (Object.prototype.hasOwnProperty.call(KNOWN_MODELS, name)) {
    return KNOWN_MODELS[name].withConfig(config);
  }

  return modelRef({
    name: `googleai/${name}`,
    config,
    configSchema: ImagenConfigSchema,
    info: {
      ...GENERIC_MODEL.info,
    },
  });
}

/**
 * Converts the Gemini API model list into action metadata for Imagen models.
 *
 * A model is included when it supports the `predict` generation method, its
 * normalized name starts with `imagen-`, and its description (if any) does not
 * contain the lowercase word `deprecated`.
 *
 * @param models Models as returned by the API's list-models call.
 * @returns One metadata entry per included model, in input order.
 */
export function listActions(models: Model[]): ActionMetadata[] {
  const metadata: ActionMetadata[] = [];

  for (const candidate of models) {
    if (!candidate.supportedGenerationMethods.includes('predict')) {
      continue;
    }
    if (!isImagenModelName(modelName(candidate.name))) {
      continue;
    }
    if (candidate.description?.includes('deprecated')) {
      continue;
    }

    const ref = model(candidate.name);
    metadata.push(
      modelActionMetadata({
        name: ref.name,
        info: ref.info,
        configSchema: ref.configSchema,
      })
    );
  }

  return metadata;
}

/**
 * Defines a model action for every entry in `KNOWN_MODELS`.
 *
 * @param options Plugin options forwarded to each `defineModel` call.
 * @returns The defined actions, in `KNOWN_MODELS` key order.
 */
export function listKnownModels(options?: GoogleAIPluginOptions) {
  const actions: ModelAction[] = [];
  for (const knownName of Object.keys(KNOWN_MODELS)) {
    actions.push(defineModel(knownName, options));
  }
  return actions;
}

/**
 * Defines the Genkit model action backing a single Imagen model.
 *
 * `checkApiKey` runs once, when the action is defined. On every call the
 * effective key is resolved by `calculateApiKey` from the plugin key and an
 * optional per-request `config.apiKey`.
 *
 * @param name Model name; resolved to a reference via `model()`.
 * @param pluginOptions Plugin options supplying the API key, API version and
 *   base URL.
 * @returns The model action. When invoked, the action throws an `Error` if
 *   the API response contains no predictions.
 */
export function defineModel(
  name: string,
  pluginOptions?: GoogleAIPluginOptions
): ModelAction {
  checkApiKey(pluginOptions?.apiKey);
  const ref = model(name);
  const baseClientOptions: ClientOptions = {
    apiVersion: pluginOptions?.apiVersion,
    baseUrl: pluginOptions?.baseUrl,
  };

  return pluginModel(
    {
      name: ref.name,
      ...ref.info,
      configSchema: ref.configSchema,
    },
    async (request, { abortSignal }) => {
      // Plugin-level client settings plus this call's abort signal.
      const callOptions = { ...baseClientOptions, signal: abortSignal };

      // One instance: the prompt text and whatever image extractImagenImage
      // finds in the request.
      const instance = {
        prompt: extractText(request),
        image: extractImagenImage(request),
      };
      const predictRequest: ImagenPredictRequest = {
        instances: [instance],
        parameters: toImagenParameters(request),
      };

      const apiKey = calculateApiKey(
        pluginOptions?.apiKey,
        request.config?.apiKey
      );
      const version = extractVersion(ref);
      const response = await imagenPredict(
        apiKey,
        version,
        predictRequest,
        callOptions
      );

      const { predictions } = response;
      if (!predictions || predictions.length === 0) {
        throw new Error(
          'Model returned no predictions. Possibly due to content filters.'
        );
      }

      const message: MessageData = {
        role: 'model',
        content: predictions.map(fromImagenPrediction),
      };

      return {
        finishReason: 'stop',
        message,
        usage: getBasicUsageStats(request.messages, message),
        custom: response,
      };
    }
  );
}

/**
 * Converts one Imagen prediction into a Genkit media part.
 *
 * The base64 image bytes are wrapped in a `data:` URL whose media type, like
 * `contentType`, is copied verbatim from the prediction's `mimeType`.
 *
 * @param p A single prediction from the predict response.
 * @returns A media part carrying the image as a data URL.
 */
function fromImagenPrediction(p: ImagenPrediction): MediaPart {
  const { bytesBase64Encoded, mimeType } = p;
  const url = `data:${mimeType};base64,${bytesBase64Encoded}`;
  return { media: { url, contentType: mimeType } };
}

/**
 * Builds the `parameters` object of an Imagen predict request from the
 * Genkit request config.
 *
 * - `sampleCount` is taken from `config.numberOfImages`, defaulting to 1.
 *   A `sampleCount` key present in the (passthrough) config overrides it.
 * - Every other config entry, including passthrough keys and
 *   `numberOfImages` itself, is forwarded unchanged.
 * - Entries whose value is `null` or `undefined` are dropped. Other falsy
 *   values (`false`, `0`, `''`) are kept so explicit settings such as
 *   `addWatermark: false` or `seed: 0` reach the API instead of being
 *   silently replaced by the server default.
 * - `apiKey` is always removed; it is used separately as the credential.
 *
 * @param request The Genkit generate request.
 * @returns The parameters to send to the predict endpoint.
 */
function toImagenParameters(
  request: GenerateRequest<typeof ImagenConfigSchema>
): ImagenParameters {
  const out = {
    sampleCount: request.config?.numberOfImages ?? 1,
    ...request.config,
  };

  // Object.keys snapshots the own keys, so deleting while iterating is safe.
  // Only nullish values are unset; falsy-but-meaningful values are kept.
  for (const key of Object.keys(out)) {
    if (out[key] == null) delete out[key];
  }

  // This is not part of the request parameters sent to the endpoint
  // It's pulled out and used separately
  delete out.apiKey;

  return out;
}

/**
 * Module internals exposed solely for unit tests; not part of the plugin's
 * public API.
 */
export const TEST_ONLY = {
  toImagenParameters,
  fromImagenPrediction,
  GENERIC_MODEL,
  KNOWN_MODELS,
};
